---
title: "External Validity: Framework, Design and Analysis"
subtitle: "Online Supplementary Materials #2 -- Simulations"
author: "Naoki Egami and Erin Hartman"
date: \today
output:
  html_document:
    df_print: paged
---

```{r setup, include=FALSE}
# Remove every object from the current environment before running.
rm(list = ls())
# Chunk defaults: show the code, hide warnings and messages.
knitr::opts_chunk$set(
  echo = TRUE,
  warning = FALSE,
  message = FALSE
)
```

```{r load_libraries, include = FALSE}
library(tidyverse)
library(dplyr)
library(kableExtra)
```

## System Information
```{r}
# Print the wall-clock time at which the run starts.
print(Sys.time())

# Keep the start time; the last line of this document uses it to report the
# total runtime.
start_time <- Sys.time()

# Print the session details (R version, platform, attached packages).
print(sessionInfo())
```

```{r}
# Make sure every output directory used below exists.
# `recursive = TRUE` also creates the parent "./generated" directory when it
# is missing; a non-recursive `dir.create()` fails (with a warning) in that
# case. `dir.exists()` is used rather than `file.exists()` so that a regular
# file with the same name is not mistaken for the directory.
output_dirs <- c(
  "./generated/appendix_2",
  "./generated/Figures",
  "./generated/Tables"
)
for (output_dir in output_dirs) {
  if (!dir.exists(output_dir)) {
    dir.create(output_dir, recursive = TRUE)
  }
}
```

Load results.
```{r}
# Restore the saved simulation output for the four designs.
# Presumably each .RData file holds one data frame named after the file
# (those names are used in the bind_rows() call below) -- confirm against the
# script that generates them.
sim_files <- c(
  "./generated/appendix_2/full_sims_nonlinear_samp_correct.RData",
  "./generated/appendix_2/full_sims_nonlinear_samp_incorrect.RData",
  "./generated/appendix_2/full_sims_linear_samp_correct.RData",
  "./generated/appendix_2/full_sims_linear_samp_incorrect.RData"
)
for (sim_file in sim_files) {
  load(sim_file)
}

# Stack the four designs into a single data frame.
full_sims_results <- bind_rows(
  full_sims_nonlinear_samp_correct,
  full_sims_nonlinear_samp_incorrect,
  full_sims_linear_samp_correct,
  full_sims_linear_samp_incorrect
)
```

Format the results.
```{r}
# Add the benchmark quantities and the display labels used by the figures and
# tables, then save the formatted results.
full_sims_results <- full_sims_results %>%
  # spTPATE: mean of the "PATE" estimates within each
  # (correct_sample, correct_model, n) group.
  group_by(correct_sample, correct_model, n) %>%
  mutate(spTPATE = mean(est[estimator == "PATE"])) %>%
  # PATE: the "PATE" estimate within the same group and simulation draw.
  # NOTE(review): `Bias_fs` is computed against `PATE` here, whereas the
  # naturalistic section of this file computes `Bias_fs` against a separate
  # "fs_PATE" benchmark -- confirm this is intended.
  group_by(correct_sample, correct_model, n, sim) %>%
  mutate(
    PATE = est[estimator == "PATE"],
    Bias_fs = est - PATE
  ) %>%
  ungroup() %>%
  # Display name of each estimator ("\n" wraps the axis labels).
  mutate(
    estimator_name = factor(
      case_when(
        estimator == "SATE" ~ "Diff-in-\nMeans",
        estimator == "IPW-logit" ~ "IPW",
        estimator == "wLS-logit" ~ "wLS",
        estimator == "OLS-proj" ~ "OLS",
        estimator == "BART-proj" ~ "BART",
        estimator == "DR-OLS-logit" ~ "AIPW\nwith OLS",
        estimator == "DR-BART-logit" ~ "AIPW\nwith BART",
        estimator == "PATE" ~ "PATE",
        estimator == "fs_PATE" ~ "PATE (fs)"
      ),
      levels = c(
        "Diff-in-\nMeans", "IPW", "wLS", "OLS", "BART",
        "AIPW\nwith OLS", "AIPW\nwith BART", "wLS\nProjection",
        "PATE", "PATE (fs)"
      )
    )
  ) %>%
  # Estimator family; used as the facet columns of the figures.
  mutate(
    estimator_type = factor(
      case_when(
        type == "SATE" ~ "SATE",
        type == "weighting" ~ "T-PATE:\nWeighting-based Estimator",
        type == "outcome" ~ "T-PATE:\nOutcome-based Estimator",
        type == "PATE" ~ "PATE",
        type == "doubly robust" ~ "T-PATE:\nDoubly Robust Estimator"
      ),
      levels = c(
        "SATE", "PATE",
        "T-PATE:\nWeighting-based Estimator",
        "T-PATE:\nOutcome-based Estimator",
        "T-PATE:\nDoubly Robust Estimator"
      )
    )
  ) %>%
  # Label the outcome model and order the sampling-model indicator Yes, No.
  mutate(
    correct_outcome = case_when(
      correct_model == "lm" ~ "Linear",
      correct_model == "bart" ~ "Non-linear"
    ),
    correct_sample = factor(correct_sample, levels = c("Yes", "No"))
  )

save(full_sims_results, file = "./generated/appendix_2/full_sims.RData")
```

Plot results.

```{r, echo = FALSE, fig.width = 12, fig.height = 8}
# Add the layers shared by the simulation figures to the ggplot object `g`:
# box plots, a dashed horizontal line at `spTPATE`, axis titles, and a facet
# grid with rows by `n` and columns by `estimator_type`. Returns the plot.
make_plot <- function(g) {
  boxes <- geom_boxplot()
  target_line <- geom_hline(aes(yintercept = spTPATE), linetype = "dashed")
  # NOTE(review): theme_bw() is a complete theme and is added after this
  # element, so the 90-degree rotation is presumably overridden -- confirm
  # which appearance is intended before reordering.
  rotated_labels <- theme(axis.text.x = element_text(angle = 90))
  panels <- facet_grid(n ~ estimator_type, scales = "free_x", space = "free_x")

  # annotate("text", x = 1, y = -1.75, label = "TPATE") +
  g + boxes + target_line +
    xlab("Estimator") + ylab("Estimate") +
    rotated_labels + panels + theme_bw()
}

# Figure A8, linear outcome: every estimator type except "PATE", coloured by
# whether the sampling model is correct; the legend is hidden in this panel.
make_plot(
  ggplot(
    data = full_sims_results %>%
      filter(estimator_type != "PATE", correct_outcome == "Linear"),
    aes(x = estimator_name, y = est, color = correct_sample)
  )
) +
  ggtitle("Simulation Results -- Linear Outcome") +
  xlab("") +
  scale_y_continuous(labels = scales::number_format(accuracy = 0.1)) +
  labs(color = "Correct Sampling Model", linetype = "True Outcome Model") +
  theme(legend.position = "none") +
  scale_color_manual(values = c("#8237FF", "#00cc6c"))

# ggsave() writes the most recently created plot.
ggsave(
  "./generated/Figures/Figure_A8_full_sim_results_linear.pdf",
  width = 8, height = 6
)

# Figure A8, non-linear outcome: same layout, legend shown at the bottom.
make_plot(
  ggplot(
    data = full_sims_results %>%
      filter(estimator_type != "PATE", correct_outcome == "Non-linear"),
    aes(x = estimator_name, y = est, color = correct_sample)
  )
) +
  ggtitle("Simulation Results -- Non-linear Outcome") +
  xlab("") +
  scale_y_continuous(labels = scales::number_format(accuracy = 0.1)) +
  labs(color = "Correct Sampling Model", linetype = "True Outcome Model") +
  theme(legend.position = "bottom") +
  scale_color_manual(values = c("#8237FF", "#00cc6c"))

ggsave(
  "./generated/Figures/Figure_A8_full_sim_results_nonlinear.pdf",
  width = 8, height = 6
)
```

Make numerics table.

```{r numerics_figure_A8, include = FALSE}
## Summarises the simulations shown in Figure A8. The table itself is written
## to "./generated/Tables/Appendix_2_A8_numerics.tex" by the cat() call below.

# Linear outcome: bias and MSE relative to spTPATE, the standard deviation of
# the estimates, and the average confidence-interval length, per sampling
# model, estimator and n.
full_sims_numerics_linear <- full_sims_results %>%
  filter(estimator_type != "PATE", correct_outcome == "Linear") %>%
  mutate(
    # Turn the plot labels into single-line table labels.
    estimator_name = estimator_name %>%
      as.character() %>%
      str_replace("\n", " ") %>%
      str_replace("Diff-in- Means", "Diff-in-Means") %>%
      factor(
        levels = c(
          "Diff-in-Means", "IPW", "wLS", "OLS", "BART",
          "AIPW with OLS", "AIPW with BART", "wLS Projection",
          "PATE", "PATE (fs)"
        )
      )
  ) %>%
  group_by(correct_outcome, correct_sample, estimator_name, estimator_type, n) %>%
  summarise(
    Bias_lin = mean(est - spTPATE),
    SE_lin = sd(est),
    MSE_lin = mean((est - spTPATE)^2),
    IL_lin = mean(ci_upper - ci_lower)
  ) %>%
  ungroup() %>%
  arrange(desc(correct_sample), estimator_name, n) %>%
  select(-correct_outcome, -estimator_type)

# Non-linear outcome: the same summaries with the "_nonlin" suffix.
full_sims_numerics_nonlinear <- full_sims_results %>%
  filter(estimator_type != "PATE", correct_outcome != "Linear") %>%
  mutate(
    estimator_name = estimator_name %>%
      as.character() %>%
      str_replace("\n", " ") %>%
      str_replace("Diff-in- Means", "Diff-in-Means") %>%
      factor(
        levels = c(
          "Diff-in-Means", "IPW", "wLS", "OLS", "BART",
          "AIPW with OLS", "AIPW with BART", "wLS Projection",
          "PATE", "PATE (fs)"
        )
      )
  ) %>%
  group_by(correct_outcome, correct_sample, estimator_name, estimator_type, n) %>%
  summarise(
    Bias_nonlin = mean(est - spTPATE),
    SE_nonlin = sd(est),
    MSE_nonlin = mean((est - spTPATE)^2),
    IL_nonlin = mean(ci_upper - ci_lower)
  ) %>%
  ungroup() %>%
  arrange(desc(correct_sample), estimator_name, n) %>%
  select(-correct_outcome, -estimator_type)

# Put the linear and non-linear summaries side by side; merge() joins on the
# columns the two tables share.
full_sims_numerics <- full_sims_numerics_linear %>%
  merge(full_sims_numerics_nonlinear) %>%
  arrange(correct_sample, estimator_name, n)

# Write the numeric results behind Figure A8 as a LaTeX table.

# Row positions of the two sampling-model groups in the sorted table; they
# delimit the row groups added with pack_rows().
rows_correct <- which(full_sims_numerics$correct_sample == "Yes")
rows_incorrect <- which(full_sims_numerics$correct_sample == "No")

numerics_table <- full_sims_numerics %>%
  select(
    estimator_name, n, Bias_lin, SE_lin, MSE_lin, IL_lin,
    Bias_nonlin, SE_nonlin, MSE_nonlin, IL_nonlin
  ) %>%
  kable(
    caption = "Numeric Values for Simulations in Figure A8.",
    booktabs = TRUE, # spelled out: `T` is an ordinary variable that can be reassigned
    col.names = c(
      "Estimator", "n", "Bias", "SE", "MSE", "Avg. Interval Length",
      "Bias", "SE", "MSE", "Avg. Interval Length"
    ),
    digits = 3,
    format = "latex"
  ) %>%
  kable_styling(latex_options = c("scale_down", "hold_position")) %>%
  pack_rows(
    "Correct Sampling Model",
    min(rows_correct), max(rows_correct),
    latex_gap_space = "1em"
  ) %>%
  pack_rows(
    "Incorrect Sampling Model",
    min(rows_incorrect), max(rows_incorrect),
    latex_gap_space = "2em"
  ) %>%
  collapse_rows(columns = 1, latex_hline = "major", valign = "middle") %>%
  add_header_above(c(" " = 2, "Linear Outcome" = 4, "Non-Linear Outcome" = 4))

# Overwrite any existing file; `append = FALSE` is the default, stated
# explicitly to match the Figure A9 table writer later in this document.
cat(
  numerics_table,
  file = "./generated/Tables/Appendix_2_A8_numerics.tex",
  append = FALSE
)
```

# Section J.2 -- Naturalistic Simulations

Load results.
```{r}
# Restore the saved naturalistic simulation output for the four designs.
# Presumably each .RData file holds one data frame named after the file
# (those names are used in the bind_rows() call below) -- confirm against the
# script that generates them.
naturalistic_files <- c(
  "./generated/appendix_2/naturalistic_sims_nonlinear_samp_correct.RData",
  "./generated/appendix_2/naturalistic_sims_nonlinear_samp_incorrect.RData",
  "./generated/appendix_2/naturalistic_sims_linear_samp_correct.RData",
  "./generated/appendix_2/naturalistic_sims_linear_samp_incorrect.RData"
)
for (naturalistic_file in naturalistic_files) {
  load(naturalistic_file)
}

# Stack the four designs into a single data frame.
full_sims_results_nat <- bind_rows(
  naturalistic_sims_nonlinear_samp_correct,
  naturalistic_sims_nonlinear_samp_incorrect,
  naturalistic_sims_linear_samp_correct,
  naturalistic_sims_linear_samp_incorrect
)
```


Recode results.

```{r}
# Add the benchmark quantities and the display labels used by the figures and
# tables, then save the recoded results.
full_sims_results_nat <- full_sims_results_nat %>%
  group_by(correct_sample, correct_model, samp_size, sim) %>%
  mutate(
    # Within each group the rows of the BART-based estimators ("BART-proj",
    # "DR-BART-logit") take the first "PATE" / "fs_PATE" entry and all other
    # rows take the second one.
    # NOTE(review): this presumes two benchmark rows per group in a fixed
    # order -- confirm against the simulation script.
    PATE = est[estimator == "PATE"][
      ifelse(estimator %in% c("BART-proj", "DR-BART-logit"), 1, 2)
    ],
    PATE_fs = est[estimator == "fs_PATE"][
      ifelse(estimator %in% c("BART-proj", "DR-BART-logit"), 1, 2)
    ],
    Bias = est - PATE,
    Bias_fs = est - PATE_fs
  ) %>%
  ungroup() %>%
  # spTPATE: mean of the matched PATE values (missing values dropped) within
  # each (correct_sample, correct_model, samp_size) group.
  group_by(correct_sample, correct_model, samp_size) %>%
  mutate(spTPATE = mean(PATE, na.rm = TRUE)) %>%
  ungroup() %>%
  # Display name of each estimator ("\n" wraps the axis labels).
  mutate(
    estimator_name = factor(
      case_when(
        estimator == "SATE" ~ "Diff-in-\nMeans",
        estimator == "IPW-logit" ~ "IPW",
        estimator == "wLS-logit" ~ "wLS",
        estimator == "OLS-proj" ~ "OLS",
        estimator == "BART-proj" ~ "BART",
        estimator == "DR-OLS-logit" ~ "AIPW\nwith OLS",
        estimator == "DR-BART-logit" ~ "AIPW\nwith BART",
        estimator == "PATE" ~ "PATE",
        estimator == "fs_PATE" ~ "PATE (fs)"
      ),
      levels = c(
        "Diff-in-\nMeans", "IPW", "wLS", "OLS", "BART",
        "AIPW\nwith OLS", "AIPW\nwith BART", "wLS\nProjection",
        "PATE", "PATE (fs)"
      )
    )
  ) %>%
  # Estimator family; used as the facet columns of the figures.
  mutate(
    estimator_type = factor(
      case_when(
        type == "SATE" ~ "SATE",
        type == "weighting" ~ "T-PATE:\nWeighting-based Estimator",
        type == "outcome" ~ "T-PATE:\nOutcome-based Estimator",
        type == "PATE" ~ "PATE",
        type == "doubly robust" ~ "T-PATE:\nDoubly Robust Estimator"
      ),
      levels = c(
        "SATE", "PATE",
        "T-PATE:\nWeighting-based Estimator",
        "T-PATE:\nOutcome-based Estimator",
        "T-PATE:\nDoubly Robust Estimator"
      )
    )
  ) %>%
  # Label the outcome model and order the sampling-model indicator Yes, No.
  # NOTE(review): the linear model is coded "ols" here, while the first
  # simulation section of this file matches on "lm" -- confirm both match the
  # saved data.
  mutate(
    correct_outcome = case_when(
      correct_model == "ols" ~ "Linear",
      correct_model == "bart" ~ "Non-linear"
    ),
    correct_sample = factor(correct_sample, levels = c("Yes", "No"))
  )

save(full_sims_results_nat, file = "./generated/appendix_2/full_sims_natural.RData")
```

Plot results.

```{r, echo = FALSE, fig.width = 10, fig.height = 6}
# Redefines make_plot() for the naturalistic simulations: adds to the ggplot
# object `g` box plots, a dashed horizontal line at `spTPATE`, axis titles, a
# facet grid with rows by `samp_size` and columns by `estimator_type`, and a
# legend placed at the bottom. Returns the plot.
make_plot <- function(g) {
  boxes <- geom_boxplot()
  target_line <- geom_hline(aes(yintercept = spTPATE), linetype = "dashed")
  # NOTE(review): theme_bw() is a complete theme and is added after this
  # element, so the 90-degree rotation is presumably overridden -- confirm
  # which appearance is intended before reordering.
  rotated_labels <- theme(axis.text.x = element_text(angle = 90))
  panels <- facet_grid(
    samp_size ~ estimator_type,
    scales = "free_x", space = "free_x"
  )
  # Added after theme_bw(), so this setting is kept.
  legend_below <- theme(legend.position = "bottom")

  # annotate("text", x = 1, y = -1.75, label = "TPATE") +
  g + boxes + target_line +
    xlab("Estimator") + ylab("Estimate") +
    rotated_labels + panels + theme_bw() + legend_below
}

# Figure A9, linear outcome: every estimator type except "PATE", coloured by
# whether the sampling model is correct; the legend is hidden in this panel.
make_plot(
  ggplot(
    data = full_sims_results_nat %>%
      filter(estimator_type != "PATE", correct_outcome == "Linear"),
    aes(x = estimator_name, y = est, color = correct_sample)
  )
) +
  ggtitle("Naturalistic Simulation Results -- Linear Outcome") +
  xlab("") +
  scale_y_continuous(labels = scales::number_format(accuracy = 0.1)) +
  labs(color = "Correct Sampling Model") +
  scale_color_manual(values = c("#8237FF", "#00cc6c")) +
  theme(legend.position = "none")

# ggsave() writes the most recently created plot.
ggsave(
  "./generated/Figures/Figure_A9_natural_sim_results_linear.pdf",
  width = 8, height = 6
)

# Figure A9, non-linear outcome: same layout, keeping the bottom legend set
# by make_plot().
make_plot(
  ggplot(
    data = full_sims_results_nat %>%
      filter(estimator_type != "PATE", correct_outcome == "Non-linear"),
    aes(x = estimator_name, y = est, color = correct_sample)
  )
) +
  ggtitle("Naturalistic Simulation Results -- Non-linear Outcome") +
  xlab("") +
  scale_y_continuous(labels = scales::number_format(accuracy = 0.1)) +
  labs(color = "Correct Sampling Model") +
  scale_color_manual(values = c("#8237FF", "#00cc6c"))

ggsave(
  "./generated/Figures/Figure_A9_natural_sim_results_nonlinear.pdf",
  width = 8, height = 6
)
```
Output numeric results.

```{r numerics_figure_A9, include = FALSE}
## Summarises the simulations shown in Figure A9. The table itself is written
## to "./generated/Tables/Appendix_2_A9_numerics.tex" by the cat() call below.

# Linear outcome: bias and MSE relative to spTPATE, the standard deviation of
# the estimates, and the average confidence-interval length, per sampling
# model, estimator and samp_size.
full_sims_numerics_linear <- full_sims_results_nat %>%
  filter(estimator_type != "PATE", correct_outcome == "Linear") %>%
  mutate(
    # Turn the plot labels into single-line table labels.
    estimator_name = estimator_name %>%
      as.character() %>%
      str_replace("\n", " ") %>%
      str_replace("Diff-in- Means", "Diff-in-Means") %>%
      factor(
        levels = c(
          "Diff-in-Means", "IPW", "wLS", "OLS", "BART",
          "AIPW with OLS", "AIPW with BART", "wLS Projection",
          "PATE", "PATE (fs)"
        )
      )
  ) %>%
  group_by(correct_outcome, correct_sample, estimator_name, estimator_type, samp_size) %>%
  summarise(
    Bias_lin = mean(est - spTPATE),
    SE_lin = sd(est),
    MSE_lin = mean((est - spTPATE)^2),
    IL_lin = mean(ci_upper - ci_lower)
  ) %>%
  ungroup() %>%
  arrange(desc(correct_sample), estimator_name, samp_size) %>%
  select(-correct_outcome, -estimator_type)

# Non-linear outcome: the same summaries with the "_nonlin" suffix.
full_sims_numerics_nonlinear <- full_sims_results_nat %>%
  filter(estimator_type != "PATE", correct_outcome != "Linear") %>%
  mutate(
    estimator_name = estimator_name %>%
      as.character() %>%
      str_replace("\n", " ") %>%
      str_replace("Diff-in- Means", "Diff-in-Means") %>%
      factor(
        levels = c(
          "Diff-in-Means", "IPW", "wLS", "OLS", "BART",
          "AIPW with OLS", "AIPW with BART", "wLS Projection",
          "PATE", "PATE (fs)"
        )
      )
  ) %>%
  group_by(correct_outcome, correct_sample, estimator_name, estimator_type, samp_size) %>%
  summarise(
    Bias_nonlin = mean(est - spTPATE),
    SE_nonlin = sd(est),
    MSE_nonlin = mean((est - spTPATE)^2),
    IL_nonlin = mean(ci_upper - ci_lower)
  ) %>%
  ungroup() %>%
  arrange(desc(correct_sample), estimator_name, samp_size) %>%
  select(-correct_outcome, -estimator_type)

# Put the linear and non-linear summaries side by side; merge() joins on the
# columns the two tables share.
full_sims_numerics <- full_sims_numerics_linear %>%
  merge(full_sims_numerics_nonlinear) %>%
  arrange(correct_sample, estimator_name, samp_size)

# Write the numeric results behind Figure A9 as a LaTeX table.

# Row positions of the two sampling-model groups in the sorted table; they
# delimit the row groups added with pack_rows().
rows_correct <- which(full_sims_numerics$correct_sample == "Yes")
rows_incorrect <- which(full_sims_numerics$correct_sample == "No")

numerics_table <- full_sims_numerics %>%
  select(
    estimator_name, samp_size, Bias_lin, SE_lin, MSE_lin, IL_lin,
    Bias_nonlin, SE_nonlin, MSE_nonlin, IL_nonlin
  ) %>%
  kable(
    caption = "Numeric Values for Simulations in Figure A9.",
    booktabs = TRUE, # spelled out: `T` is an ordinary variable that can be reassigned
    col.names = c(
      "Estimator", "n", "Bias", "SE", "MSE", "Avg. Interval Length",
      "Bias", "SE", "MSE", "Avg. Interval Length"
    ),
    digits = 3,
    format = "latex"
  ) %>%
  kable_styling(latex_options = c("scale_down", "hold_position")) %>%
  pack_rows(
    "Correct Sampling Model",
    min(rows_correct), max(rows_correct),
    latex_gap_space = "1em"
  ) %>%
  pack_rows(
    "Incorrect Sampling Model",
    min(rows_incorrect), max(rows_incorrect),
    latex_gap_space = "2em"
  ) %>%
  collapse_rows(columns = 1, latex_hline = "major", valign = "middle") %>%
  add_header_above(c(" " = 2, "Linear Outcome" = 4, "Non-Linear Outcome" = 4))

# Overwrite any existing file rather than appending to it.
cat(
  numerics_table,
  file = "./generated/Tables/Appendix_2_A9_numerics.tex",
  append = FALSE
)
```


Total runtime is `r round(difftime(Sys.time(), start_time, units = "mins"), 2)` minutes.